{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Understanding RNN structure\n",
    "- Distinguished from feedforward nets, RNNs are structures that can well handle data with \"sequential\" format by preserving previous \"state\" \n",
    "- Thus, grasping concepts of **\"sequences\"** and (hidden) **\"states\"** in RNNs is crucial\n",
    "\n",
    "<br>\n",
    "<img src=\"http://karpathy.github.io/assets/rnn/charseq.jpeg\" style=\"width: 500px\"/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from keras.models import Model, Sequential\n",
    "from keras.layers import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. SimpleRNN "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Input shape of SimpleRNN should be 3D tensor => (batch_size, timesteps, input_dim)\n",
    "- **batch_size**: ommitted when creating RNN instance (== None). Usually designated when fitting model.\n",
    "- **timesteps**: number of input sequence per batch\n",
    "- **input_dim**: dimensionality of input sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# for instance, consider below array\n",
    "x = np.array([[\n",
    "             [1,    # => input_dim 1\n",
    "              2,    # => input_dim 2 \n",
    "              3],   # => input_dim 3     # => timestep 1                            \n",
    "             [4, 5, 6]                   # => timestep 2\n",
    "             ],                                  # => batch 1\n",
    "             [[7, 8, 9], [10, 11, 12]],          # => batch 2\n",
    "             [[13, 14, 15], [16, 17, 18]]        # => batch 3\n",
    "             ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(Batch size, timesteps, input_dim) =  (3, 2, 3)\n"
     ]
    }
   ],
   "source": [
    "print('(Batch size, timesteps, input_dim) = ',x.shape)"
   ]
  },
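  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the three axes concrete, the sketch below indexes the same array along each axis. It uses only NumPy and re-creates `x` so that the cell can run on its own."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# same values as above: 3 samples, 2 timesteps, 3 features\n",
    "x = np.array([[[1, 2, 3], [4, 5, 6]],\n",
    "              [[7, 8, 9], [10, 11, 12]],\n",
    "              [[13, 14, 15], [16, 17, 18]]])\n",
    "\n",
    "print(x[0])          # first sample: both timesteps, shape (2, 3)\n",
    "print(x[0, 1])       # second timestep of the first sample: [4 5 6]\n",
    "print(x[0, 1, 2])    # third feature of that timestep: 6\n",
    "print(x[:, -1, :])   # last timestep of every sample, shape (3, 3)"
   ]
  },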
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rnn = SimpleRNN(50)(Input(shape = (10,))) => error\n",
    "# rnn = SimpleRNN(50)(Input(shape = (10, 30, 40))) => error\n",
    "rnn = SimpleRNN(50)(Input(shape = (10, 30)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**return_state** = **return_sequences** = **False** ====> output_shape = **(batch_size = None, num_units)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(?, 50)\n"
     ]
    }
   ],
   "source": [
    "rnn = SimpleRNN(50)(Input(shape = (10, 30)))\n",
    "print(rnn.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**return_sequences = True** ====> output_shape = **(batch_size, timesteps, num_units)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(?, ?, 50)\n"
     ]
    }
   ],
   "source": [
    "rnn = SimpleRNN(50, return_sequences = True)(Input(shape = (10, 30)))\n",
    "print(rnn.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "return_state = True ===> outputs list of tensor: **[output, state]**\n",
    "- if return_sequences == False     =>>    output_shape = (batch_size, num_units)\n",
    "- if return_sequences == True      =>>    output_shape = (batch_size, timesteps, num_units)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(?, 50)\n",
      "(?, 50)\n"
     ]
    }
   ],
   "source": [
    "rnn = SimpleRNN(50, return_sequences = False, return_state = True)(Input(shape = (10, 30)))\n",
    "print(rnn[0].shape)         # shape of output\n",
    "print(rnn[1].shape)         # shape of last state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(?, ?, 50)\n",
      "(?, 50)\n"
     ]
    }
   ],
   "source": [
    "rnn = SimpleRNN(50, return_sequences = True, return_state = True)(Input(shape = (10, 30)))\n",
    "print(rnn[0].shape)         # shape of output\n",
    "print(rnn[1].shape)         # shape of last state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Current output and state can be unpacked as below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "output, state = SimpleRNN(50, return_sequences = True, return_state = True)(Input(shape = (10, 30)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(?, ?, 50)\n",
      "(?, 50)\n"
     ]
    }
   ],
   "source": [
    "print(output.shape)\n",
    "print(state.shape)"
   ]
  },
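  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what the output and the state hold, the sketch below runs a SimpleRNN-style recurrence by hand in NumPy. It assumes a `tanh` activation and a zero initial state, and the weights `W`, `U`, and `b` are random placeholders rather than weights taken from a Keras layer. Under these assumptions, the last timestep of the full output sequence is the same array as the final state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "batch_size, timesteps, input_dim, num_units = 3, 10, 30, 50\n",
    "\n",
    "x_seq = rng.randn(batch_size, timesteps, input_dim)\n",
    "W = rng.randn(input_dim, num_units) * 0.1    # input-to-hidden weights (placeholder values)\n",
    "U = rng.randn(num_units, num_units) * 0.1    # hidden-to-hidden weights (placeholder values)\n",
    "b = np.zeros(num_units)                      # bias\n",
    "\n",
    "h = np.zeros((batch_size, num_units))        # initial state\n",
    "outputs = []\n",
    "for t in range(timesteps):\n",
    "    # the new state depends on the current input and the previous state\n",
    "    h = np.tanh(x_seq[:, t, :] @ W + h @ U + b)\n",
    "    outputs.append(h)\n",
    "\n",
    "sequence_output = np.stack(outputs, axis=1)\n",
    "print(sequence_output.shape)                            # (3, 10, 50), as with return_sequences = True\n",
    "print(h.shape)                                          # (3, 50), the last state\n",
    "print(np.array_equal(sequence_output[:, -1, :], h))     # True"
   ]
  },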
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. LSTM\n",
    "- Outputs of LSTM are quite similar to those of RNNs, but there exist subtle differences\n",
    "- If you compare two diagrams below, there is one more type of \"state\" that is preserved to next module\n",
    "\n",
    "<br>\n",
    "<img src=\"http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-SimpleRNN.png\" style=\"width: 500px\"/>\n",
    "\n",
    "<center> Standard RNN </center>\n",
    "\n",
    "<br>\n",
    "<img src=\"http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-chain.png\" style=\"width: 500px\"/>\n",
    "\n",
    "<center> LSTM </center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In addition to \"hidden state (ht)\" in RNN, there exist \"cell state (Ct)\" in LSTM structure\n",
    "\n",
    "<br>\n",
    "<img src=\"http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-focus-o.png\" style=\"width: 500px\"/>\n",
    "\n",
    "<center> Hidden State </center>\n",
    "\n",
    "<br>\n",
    "<img src=\"http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-focus-C.png\" style=\"width: 500px\"/>\n",
    "\n",
    "<center> Cell State </center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "lstm = LSTM(50)(Input(shape = (10, 30)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(?, 50)\n"
     ]
    }
   ],
   "source": [
    "print(lstm.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(?, 50)\n",
      "(?, 50)\n",
      "(?, 50)\n"
     ]
    }
   ],
   "source": [
    "lstm = LSTM(50, return_sequences = False, return_state = True)(Input(shape = (10, 30)))\n",
    "print(lstm[0].shape)         # shape of output\n",
    "print(lstm[1].shape)         # shape of hidden state\n",
    "print(lstm[2].shape)         # shape of cell state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(?, ?, 50)\n",
      "(?, 50)\n",
      "(?, 50)\n"
     ]
    }
   ],
   "source": [
    "lstm = LSTM(50, return_sequences = True, return_state = True)(Input(shape = (10, 30)))\n",
    "print(lstm[0].shape)         # shape of output\n",
    "print(lstm[1].shape)         # shape of hidden state\n",
    "print(lstm[2].shape)         # shape of cell state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "output, hidden_state, cell_state = LSTM(50, return_sequences = True, return_state = True)(Input(shape = (10, 30)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(?, ?, 50)\n",
      "(?, 50)\n",
      "(?, 50)\n"
     ]
    }
   ],
   "source": [
    "print(output.shape)\n",
    "print(hidden_state.shape)\n",
    "print(cell_state.shape)"
   ]
  },
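  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The relationship between the output and the hidden state can also be checked on actual values. This is a minimal sketch that assumes the Keras version used in this notebook; the names `seq_out`, `last_h`, `last_c`, and `check_model` are introduced here for illustration only. The last timestep of the output sequence should match the hidden state, while the cell state is returned as a separate tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from keras.models import Model\n",
    "from keras.layers import Input, LSTM\n",
    "\n",
    "inputs = Input(shape = (10, 30))\n",
    "seq_out, last_h, last_c = LSTM(50, return_sequences = True, return_state = True)(inputs)\n",
    "check_model = Model(inputs = inputs, outputs = [seq_out, last_h, last_c])\n",
    "\n",
    "# random batch of 4 sequences, each with 10 timesteps and 30 features\n",
    "x_batch = np.random.random((4, 10, 30))\n",
    "seq_val, h_val, c_val = check_model.predict(x_batch)\n",
    "\n",
    "print(seq_val.shape)                            # (4, 10, 50)\n",
    "print(h_val.shape)                              # (4, 50)\n",
    "print(c_val.shape)                              # (4, 50)\n",
    "print(np.allclose(seq_val[:, -1, :], h_val))    # expected: True"
   ]
  },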
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. GRU\n",
    "- GRU, Popular variant of LSTM, does not have cell state\n",
    "- Hence, it has only hidden state, as simple RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "gru = GRU(50)(Input(shape = (10, 30)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(?, 50)\n"
     ]
    }
   ],
   "source": [
    "print(gru.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(?, 50)\n",
      "(?, 50)\n"
     ]
    }
   ],
   "source": [
    "gru = GRU(50, return_sequences = False, return_state = True)(Input(shape = (10, 30)))\n",
    "print(gru[0].shape)         # shape of output\n",
    "print(gru[1].shape)         # shape of hidden state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(?, ?, 50)\n",
      "(?, 50)\n"
     ]
    }
   ],
   "source": [
    "gru = GRU(50, return_sequences = True, return_state = True)(Input(shape = (10, 30)))\n",
    "print(gru[0].shape)         # shape of output\n",
    "print(gru[1].shape)         # shape of hidden state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "output, hidden_state = GRU(50, return_sequences = True, return_state = True)(Input(shape = (10, 30)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(?, ?, 50)\n",
      "(?, 50)\n"
     ]
    }
   ],
   "source": [
    "print(output.shape)\n",
    "print(hidden_state.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
